#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Aug 2 2021
Using Hg benchmark files to compare seasonal cycles of different runs with Hg0 dry deposition/reduction settings
@author: arifeinberg
"""

#import os
#os.chdir('/Users/arifeinberg/target2/fs03/d0/arifein/python/pythonHgBenchmark')
import matplotlib
matplotlib.use('Agg')

from helper_functions import open_Hg
import xarray as xr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from SiteLevels import levels
from helper_functions import ds_sel_yr

#%% Plot of seasonal cycle for different lat regions
def Seasonal_Lat_Regions_List(ds_list, ds_names, ds_num, Year = None):
    """Plot observed and modelled seasonal cycles of TGM by latitude region.

    Reads the monthly site observations from 'data/TGMSiteMonthly.csv',
    computes regional monthly statistics for the observations and for every
    model dataset with `filter_sites_region_list`, and draws one panel per
    region (currently the Northern and Southern mid-latitudes) with a single
    shared legend on the right. For every curve the seasonal amplitude
    (max - min) and the annual mean are printed to stdout.

    Parameters
    ----------
    ds_list : list of xarray datasets
         Model datasets
    ds_names : list of strings
        Names of model simulations, used as legend labels
    ds_num : list of strings
        Identifiers of model simulations, used to build the column names
        '<id>_mean' / '<id>_std' of the regional DataFrames
    Year : int or list of int, optional
        Optional parameter to only select subset of years (forwarded to
        `filter_sites_region_list`)

    Returns
    -------
    matplotlib.figure.Figure
        The figure holding the regional panels and the shared legend.

    Raises
    ------
    ValueError
        If `ds_names` or `ds_num` has fewer entries than `ds_list`.
    """
    n_sims = len(ds_list)
    if len(ds_names) < n_sims or len(ds_num) < n_sims:
        raise ValueError('ds_names and ds_num need one entry per dataset in ds_list')

    # Import the observed data from the sites
    Hgobs = pd.read_csv('data/TGMSiteMonthly.csv',  skiprows=[0], na_values=(-9999))
    Hgobs.columns=['SiteID', 'Lat', 'Lon','Month', 'Year', 'Concentration', 'Standard deviation']

    # Site IDs present in the observations; built once for fast membership tests
    sites_in_obs = set(Hgobs.SiteID)

    # Site groups per region, restricted to sites that actually have data.
    # The polar regions are currently not plotted; their site lists are kept
    # here for reference:
    #   Arctic:    ALT, VRS, ZEP, AND, PAL, AMD
    #   Antarctic: TRO, DDU, DMC
    SouthMidLat = ['CPT', 'AMS', 'BAR']
    SouthMidLat = [e for e in SouthMidLat if e in sites_in_obs]

    NorthMidLat = ['MHD', 'UDH', 'KEJ',  'HTW', 'PNY', 'ATN', 'YKV', 'GRB','BIR', 'WAL', 'BRA', 'SAT', 'THOMPFARM', 'SCO', 'STIWELL', 'EBG']
    NorthMidLat = [e for e in NorthMidLat if e in sites_in_obs]

    # Calculate mean, std for model and obs in each region
    NML_df = filter_sites_region_list(NorthMidLat, Hgobs, ds_list, ds_num, Year)
    SML_df = filter_sites_region_list(SouthMidLat, Hgobs, ds_list, ds_num, Year)

    # Regions in plotting order (one panel each)
    Regions_df_all = [NML_df, SML_df]
    Region_names = ['NH Midlatitudes (30-53 °N)','SH Midlatitudes (34-41 °S)']

    # One row of panels; leave room on the right for the shared legend
    RegPlot,  ax = plt.subplots(1, 2, figsize=[18,9],
                                    gridspec_kw=dict(hspace=0.3, wspace=0.4))
    RegPlot.subplots_adjust(right=0.75)
    axes = ax.flatten()

    # Per-simulation line styling; cycled if there are more simulations
    # than entries so that any number of datasets can be plotted
    colors = ['#e41a1c','#377eb8', '#377eb8', '#b2df8a', '#33a02c','#984ea3','#beaed4']
    linestyles = ['solid','dotted','solid','solid','solid','solid','solid']
    lws = [4,4,4,4,4,4,4]

    # Tick labels: first letter of each month
    mn = ['J','F','M','A','M','J','J','A','S','O','N','D']

    # Legend entries are collected from the first panel while plotting, in
    # the order Obs, then the simulations as given in ds_names
    legend_handles = []
    legend_labels = []

    # Loop over regions and plot
    for ii, (iax, region_name, Reg_df) in enumerate(zip(axes, Region_names, Regions_df_all)):
        print(region_name)

        # Plot the observations and their error.
        obs_handle = iax.errorbar(Reg_df.index, Reg_df['Obs_mean'], yerr=Reg_df['Obs_std'],
                                  color='k', capsize=2, linewidth=4, label='Obs')
        if ii == 0:
            legend_handles.append(obs_handle)
            legend_labels.append('Obs')

        # print seasonal amplitude and annual mean
        print('Obs')
        print(Reg_df['Obs_mean'].max() - Reg_df['Obs_mean'].min())
        print(Reg_df['Obs_mean'].mean())

        for jj in range(n_sims):
            sim_mean = Reg_df[ds_num[jj] + '_mean']
            style = dict(color=colors[jj % len(colors)],
                         linestyle=linestyles[jj % len(linestyles)],
                         linewidth=lws[jj % len(lws)],
                         label=ds_names[jj])

            # print seasonal amplitude and annual mean
            print(ds_names[jj])
            print(sim_mean.max() - sim_mean.min())
            print(sim_mean.mean())

            if jj == 0:
                # Only the first simulation is drawn with error bars
                sim_handle = iax.errorbar(Reg_df.index, sim_mean,
                                          yerr=Reg_df[ds_num[jj] + '_std'],
                                          capsize=2, **style)
            else:
                sim_handle = iax.plot(Reg_df.index, sim_mean, **style)[0]

            if ii == 0:
                legend_handles.append(sim_handle)
                legend_labels.append(ds_names[jj])

        # Label the x and y axis.
        iax.set_xlabel('Month', fontsize=25)
        iax.set_ylabel('Total Gaseous Hg (ng/m$^3$)', fontsize=25)

        # Add a title.
        iax.set_title(region_name, fontsize=27, fontweight='bold')
        # Set ticks to every month and label them with the month initials
        iax.set_xticks(Reg_df.index)
        iax.set_xticklabels(mn, fontsize=23)
        iax.tick_params(axis='y', labelsize=23 )

    # Shared legend for the whole figure
    RegPlot.legend(legend_handles, legend_labels, loc = 'center right', fontsize=25)

    return RegPlot

def filter_sites_region_list(Region, Hgobs, ds_list, ds_num, Year = None):
    """Calculate the regional seasonal cycle in observations and in each model simulation.

    For the observations, the concentrations of all sites in `Region` are
    pooled and the mean and standard deviation are taken per calendar month.
    For each model dataset, Hg0 + Hg2 is sampled at the grid cell nearest to
    every site (at the model level returned by `levels(site)`), converted to
    ng/m^3, averaged to a monthly climatology per site, and then the mean and
    standard deviation across sites are taken.

    Parameters
    ----------
    Region : list
        List of strings giving the site names to average over. Sites without
        any row in `Hgobs` are skipped so that observations and model are
        averaged over the same set of sites.
    Hgobs : DataFrame
        Observational dataset for seasonal cycle; must provide the columns
        'SiteID', 'Lat', 'Lon', 'Month' and 'Concentration'
    ds_list : list of xarray datasets
         Model datasets
    ds_num : list of strings
        Identifiers of model simulations, used as column-name prefixes
    Year : int or list of int, optional
        Optional parameter to only select subset of years (passed on to
        `ds_sel_yr`)

    Returns
    -------
    DataFrame
        One row per month with the columns 'Obs_mean', 'Obs_std' and, for
        every simulation, '<id>_mean' and '<id>_std'.

    Raises
    ------
    ValueError
        If none of the sites in `Region` is present in `Hgobs`.
    """

    # Calculate constant for the unit conversion factor from vmr to  ng/m^3
    R = 8.314462 # m^3 Pa K^-1 mol ^-1
    MW_Hg = 200.59 # g mol^-1
    ng_g = 1e9 # ng/g

    stdpressure = 101325 # Pascals
    stdtemp = 273.15 # Kelvins

    unit_conv = stdpressure / R / stdtemp * MW_Hg * ng_g # converter from vmr to ng m^-3

    # Extract the rows of each site in the region and remember the site's own
    # coordinates. Taking the coordinates per site (instead of the unique
    # latitudes/longitudes of the pooled data) keeps site, lat and lon aligned
    # even when two sites share a latitude or a longitude.
    site_frames = []
    site_coords = []  # (site ID, latitude, longitude)
    for isite in Region:
        site_df = Hgobs[Hgobs.SiteID == isite]
        if site_df.empty:
            continue
        site_frames.append(site_df)
        site_coords.append((isite, site_df.Lat.iloc[0], site_df.Lon.iloc[0]))

    if not site_frames:
        raise ValueError('None of the sites in Region are present in Hgobs')

    All_region_df = pd.concat(site_frames)

    # Calculate the mean and standard deviation of observations for each month.
    # Only the concentration column is aggregated: aggregating the whole frame
    # fails on recent pandas versions because of the string SiteID column.
    monthly_conc = All_region_df.groupby('Month')['Concentration']
    obs_mean = np.asarray(monthly_conc.mean())
    obs_std = np.asarray(monthly_conc.std())

    # initialize dictionary and fill it
    data_dic = dict()
    data_dic['Obs_mean'] = obs_mean
    data_dic['Obs_std'] = obs_std

    # Loop over the datasets
    for j, jds in enumerate(ds_list):
        # Allow subsetting for years of the simulation, if inputted into the function
        j_Hg0_yr = ds_sel_yr(jds, 'SpeciesConc_Hg0', Year)
        j_Hg2_yr = ds_sel_yr(jds, 'SpeciesConc_Hg2', Year)

        # Monthly TGM climatology of this simulation at each site
        site_clims = []
        for isite, site_lat, site_lon in site_coords:
            site_lev = levels(isite)  # model level used for this site
            j_Hg0_site = j_Hg0_yr.isel(lev=site_lev).\
                sel(lat=[site_lat], lon=[site_lon], method='nearest').squeeze()
            j_Hg2_site = j_Hg2_yr.isel(lev=site_lev).\
                sel(lat=[site_lat], lon=[site_lon], method='nearest').squeeze()

            # Calculate TGM values as sum of Hg0 and Hg2
            Reg_j_mod = (j_Hg0_site + j_Hg2_site) * unit_conv

            # calculate climatology (needed if more than one year are averaged)
            site_clims.append(Reg_j_mod.groupby('time.month').mean())

        # Stack all sites along a new dimension in one go; this also creates
        # the dimension when the region holds a single site.
        Reg_DS_j = xr.concat(site_clims, dim='concat_dims')

        # Mean and standard deviation across sites for this simulation
        data_dic[ds_num[j] + '_mean'] = np.asarray(Reg_DS_j.mean('concat_dims'), dtype=float)
        data_dic[ds_num[j] + '_std'] = np.asarray(Reg_DS_j.std('concat_dims'), dtype=float)

    # Save results in a Pandas DataFrame
    Out_df = pd.DataFrame(data_dic)

    return Out_df
#%% Open the Hg species output of every simulation and draw the figure

# Identifiers of the GEOS-Chem runs to compare
# (earlier subset: ['0005','0016','0017','0018','0019'])
runs = ['0005','0016','0017','0018','0019','0102','0105']

# Placeholders, one slot per run; not filled anywhere in this script
fns_CHC = [None] * len(runs)
ds_list_CHC = [None] * len(runs)

# Path of the monthly species-concentration file of each run
fns = ['../../GEOS-Chem_runs/run' + run_id + '/OutputDir/GEOSChem.SpeciesConc.alltime_m.nc4'
       for run_id in runs]

# Year to analyze from simulations (same year for every run)
year_to_analyze = [2015 for _ in runs]

# Open one dataset per run
ds_list = [xr.open_dataset(fn) for fn in fns]

# Display names of the runs, in the same order as `runs`
# (earlier subset: ['BASE', 'AMAZON_L', 'AMAZON_H'])
ds_names = ['BASE', 'OBRIST','OBRIST_R','AMAZON_L', 'AMAZON_U','NEWCHEM','NEWCHEM_D']

# Build the regional seasonal-cycle figure and write it to disk
figure_sites = Seasonal_Lat_Regions_List(ds_list, ds_names, runs, year_to_analyze)
figure_sites.savefig('Figures/Fig3_season_region.pdf',bbox_inches = 'tight')
